# testing/suite/test_results.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

import datetime
import re

from .. import engines
from .. import fixtures
from ..assertions import eq_
from ..config import requirements
from ..schema import Column
from ..schema import Table
from ... import DateTime
from ... import func
from ... import Integer
from ... import quoted_name
from ... import select
from ... import sql
from ... import String
from ... import testing
from ... import text


class RowFetchTest(fixtures.TablesTest):
    """Round-trip tests for the various access styles of a result Row:
    attribute, string key, positional index, and Column object."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # simple integer-pk table used by the row-access tests
        Table(
            "plain_pk",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )
        # table with a DateTime column, for the scalar-subquery test
        Table(
            "has_dates",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("today", DateTime),
        )

    @classmethod
    def insert_data(cls, connection):
        plain_rows = [
            {"id": 1, "data": "d1"},
            {"id": 2, "data": "d2"},
            {"id": 3, "data": "d3"},
        ]
        connection.execute(cls.tables.plain_pk.insert(), plain_rows)

        connection.execute(
            cls.tables.has_dates.insert(),
            [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
        )

    def _first_row(self, connection):
        """Return the first row of plain_pk ordered by primary key."""
        table = self.tables.plain_pk
        stmt = table.select().order_by(table.c.id)
        return connection.execute(stmt).first()

    def test_via_attr(self, connection):
        """Row values are reachable as attributes."""
        row = self._first_row(connection)

        eq_(row.id, 1)
        eq_(row.data, "d1")

    def test_via_string(self, connection):
        """Row values are reachable via string keys on ``._mapping``."""
        row = self._first_row(connection)

        eq_(row._mapping["id"], 1)
        eq_(row._mapping["data"], "d1")

    def test_via_int(self, connection):
        """Row values are reachable by positional index."""
        row = self._first_row(connection)

        eq_(row[0], 1)
        eq_(row[1], "d1")

    def test_via_col_object(self, connection):
        """Row values are reachable via Column objects on ``._mapping``."""
        row = self._first_row(connection)

        eq_(row._mapping[self.tables.plain_pk.c.id], 1)
        eq_(row._mapping[self.tables.plain_pk.c.data], "d1")

    @requirements.duplicate_names_in_cursor_description
    def test_row_with_dupe_names(self, connection):
        """Two result columns sharing a name are both delivered."""
        table = self.tables.plain_pk
        stmt = select(
            table.c.data,
            table.c.data.label("data"),
        ).order_by(table.c.id)

        result = connection.execute(stmt)
        row = result.first()
        eq_(result.keys(), ["data", "data"])
        eq_(row, ("d1", "d1"))

    def test_row_w_scalar_select(self, connection):
        """test that a scalar select as a column is returned as such
        and that type conversion works OK.

        (this is half a SQLAlchemy Core test and half to catch database
        backends that may have unusual behavior with scalar selects.)

        """
        datetable = self.tables.has_dates
        scalar_sq = select(datetable.alias("x").c.today).scalar_subquery()
        stmt = select(datetable.c.id, scalar_sq.label("somelabel"))
        row = connection.execute(stmt).first()

        eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))


class NameDenormalizeTest(fixtures.TablesTest):
    """Tests for how column names round trip on backends that
    "denormalize" case-insensitive identifiers (e.g. store them as
    UPPERCASE), across several statement styles and the
    ``driver_column_names`` execution option."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # one column per naming scenario: all-lowercase, all-uppercase,
        # mixed case, plus explicitly-quoted lowercase and uppercase
        cls.tables.denormalize_table = Table(
            "denormalize_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("all_lowercase", Integer),
            Column("ALL_UPPERCASE", Integer),
            Column("MixedCase", Integer),
            Column(quoted_name("all_lowercase_quoted", quote=True), Integer),
            Column(quoted_name("ALL_UPPERCASE_QUOTED", quote=True), Integer),
        )

    @classmethod
    def insert_data(cls, connection):
        # a single row; each column's value identifies it positionally
        # in the assertions below
        connection.execute(
            cls.tables.denormalize_table.insert(),
            {
                "id": 1,
                "all_lowercase": 5,
                "ALL_UPPERCASE": 6,
                "MixedCase": 7,
                "all_lowercase_quoted": 8,
                "ALL_UPPERCASE_QUOTED": 9,
            },
        )

    def _assert_row_mapping(self, row, mapping, include_cols=None):
        """Assert ``row`` matches ``mapping`` via ``._mapping``,
        attribute access, and positional index; when ``include_cols``
        is given (a sequence of Column objects positionally aligned
        with ``mapping``'s keys), also via Column-object keys."""
        eq_(row._mapping, mapping)

        # string-key and attribute access
        for k in mapping:
            eq_(row._mapping[k], mapping[k])
            eq_(getattr(row, k), mapping[k])

        # positional access; relies on mapping's insertion order
        # matching the statement's column order
        for idx, k in enumerate(mapping):
            eq_(row[idx], mapping[k])

        if include_cols:
            # Column objects matched positionally against mapping keys
            for col, (idx, k) in zip(include_cols, enumerate(mapping)):
                eq_(row._mapping[col], mapping[k])

    @testing.variation(
        "stmt_type", ["driver_sql", "text_star", "core_select", "text_cols"]
    )
    @testing.variation("use_driver_cols", [True, False])
    def test_cols_driver_cols(self, connection, stmt_type, use_driver_cols):
        """Exercise result-row keys for each statement style, with and
        without the ``driver_column_names`` execution option."""
        if stmt_type.driver_sql or stmt_type.text_star or stmt_type.text_cols:
            # render "SELECT * FROM denormalize_table" as a raw string
            # for the exec_driver_sql / text() variants
            stmt = select("*").select_from(self.tables.denormalize_table)
            text_stmt = str(stmt.compile(connection))

            if stmt_type.text_star or stmt_type.text_cols:
                stmt = text(text_stmt)

                if stmt_type.text_cols:
                    # text() with explicit column objects attached
                    stmt = stmt.columns(*self.tables.denormalize_table.c)
        elif stmt_type.core_select:
            stmt = select(self.tables.denormalize_table)
        else:
            stmt_type.fail()

        if use_driver_cols:
            execution_options = {"driver_column_names": True}
        else:
            execution_options = {}

        if stmt_type.driver_sql:
            row = connection.exec_driver_sql(
                text_stmt, execution_options=execution_options
            ).one()
        else:
            row = connection.execute(
                stmt,
                execution_options=execution_options,
            ).one()

        if (
            stmt_type.core_select and not use_driver_cols
        ) or not testing.requires.denormalized_names.enabled:
            # keys come back exactly as declared on the Table
            self._assert_row_mapping(
                row,
                {
                    "id": 1,
                    "all_lowercase": 5,
                    "ALL_UPPERCASE": 6,
                    "MixedCase": 7,
                    "all_lowercase_quoted": 8,
                    "ALL_UPPERCASE_QUOTED": 9,
                },
            )

        if testing.requires.denormalized_names.enabled:
            # with driver column names, raw cursor.description
            # is used.  this is clearly not useful for non-quoted names.
            if use_driver_cols:
                self._assert_row_mapping(
                    row,
                    {
                        "ID": 1,
                        "ALL_LOWERCASE": 5,
                        "ALL_UPPERCASE": 6,
                        "MixedCase": 7,
                        "all_lowercase_quoted": 8,
                        "ALL_UPPERCASE_QUOTED": 9,
                    },
                )
            else:
                if stmt_type.core_select or stmt_type.text_cols:
                    # statements with column metadata restore the
                    # declared names
                    self._assert_row_mapping(
                        row,
                        {
                            "id": 1,
                            "all_lowercase": 5,
                            "ALL_UPPERCASE": 6,
                            "MixedCase": 7,
                            "all_lowercase_quoted": 8,
                            "ALL_UPPERCASE_QUOTED": 9,
                        },
                        include_cols=self.tables.denormalize_table.c,
                    )
                else:
                    # plain textual statements: unquoted names are
                    # normalized back to lowercase
                    self._assert_row_mapping(
                        row,
                        {
                            "id": 1,
                            "all_lowercase": 5,
                            "all_uppercase": 6,
                            "MixedCase": 7,
                            "all_lowercase_quoted": 8,
                            "all_uppercase_quoted": 9,
                        },
                        include_cols=None,
                    )

        else:
            self._assert_row_mapping(
                row,
                {
                    "id": 1,
                    "all_lowercase": 5,
                    "ALL_UPPERCASE": 6,
                    "MixedCase": 7,
                    "all_lowercase_quoted": 8,
                    "ALL_UPPERCASE_QUOTED": 9,
                },
                include_cols=(
                    self.tables.denormalize_table.c
                    if stmt_type.core_select or stmt_type.text_cols
                    else None
                ),
            )


class PercentSchemaNamesTest(fixtures.TablesTest):
    """tests using percent signs, spaces in table and column names.

    This didn't work for PostgreSQL / MySQL drivers for a long time
    but is now supported.

    """

    __requires__ = ("percent_schema_names",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        cls.tables.percent_table = Table(
            "percent%table",
            metadata,
            Column("percent%", Integer),
            Column("spaces % more spaces", Integer),
        )
        # a lightweight table construct against the same name, so the
        # percent signs are exercised without full Table metadata
        cls.tables.lightweight_percent_table = sql.table(
            "percent%table",
            sql.column("percent%"),
            sql.column("spaces % more spaces"),
        )

    def test_single_roundtrip(self, connection):
        """INSERT the four rows one statement at a time."""
        table = self.tables.percent_table
        rows = [
            {"percent%": 5, "spaces % more spaces": 12},
            {"percent%": 7, "spaces % more spaces": 11},
            {"percent%": 9, "spaces % more spaces": 10},
            {"percent%": 11, "spaces % more spaces": 9},
        ]
        for row in rows:
            connection.execute(table.insert(), row)
        self._assert_table(connection)

    def test_executemany_roundtrip(self, connection):
        """INSERT three of the rows via a single executemany."""
        table = self.tables.percent_table
        connection.execute(
            table.insert(), {"percent%": 5, "spaces % more spaces": 12}
        )
        connection.execute(
            table.insert(),
            [
                {"percent%": 7, "spaces % more spaces": 11},
                {"percent%": 9, "spaces % more spaces": 10},
                {"percent%": 11, "spaces % more spaces": 9},
            ],
        )
        self._assert_table(connection)

    @requirements.insert_executemany_returning
    def test_executemany_returning_roundtrip(self, connection):
        """executemany with RETURNING delivers one row per parameter
        set, with percent-named columns intact."""
        table = self.tables.percent_table
        connection.execute(
            table.insert(), {"percent%": 5, "spaces % more spaces": 12}
        )
        returning_stmt = table.insert().returning(
            table.c["percent%"],
            table.c["spaces % more spaces"],
        )
        result = connection.execute(
            returning_stmt,
            [
                {"percent%": 7, "spaces % more spaces": 11},
                {"percent%": 9, "spaces % more spaces": 10},
                {"percent%": 11, "spaces % more spaces": 9},
            ],
        )
        eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
        self._assert_table(connection)

    def _assert_table(self, conn):
        """Verify the inserted data through the full Table, the
        lightweight construct, and an alias of each, then UPDATE via
        the percent-named column."""
        full_table = self.tables.percent_table
        light_table = self.tables.lightweight_percent_table

        constructs = (
            full_table,
            full_table.alias(),
            light_table,
            light_table.alias(),
        )
        for t in constructs:
            pct_col = t.c["percent%"]
            spaces_col = t.c["spaces % more spaces"]

            all_rows = conn.execute(
                t.select().order_by(pct_col)
            ).fetchall()
            eq_(all_rows, [(5, 12), (7, 11), (9, 10), (11, 9)])

            filtered = conn.execute(
                t.select()
                .where(spaces_col.in_([9, 10]))
                .order_by(pct_col)
            ).fetchall()
            eq_(filtered, [(9, 10), (11, 9)])

            first = conn.execute(
                t.select().order_by(pct_col)
            ).first()
            eq_(first._mapping["percent%"], 5)
            eq_(first._mapping["spaces % more spaces"], 12)

            eq_(first._mapping[pct_col], 5)
            eq_(first._mapping[spaces_col], 12)

        conn.execute(
            full_table.update().values(
                {full_table.c["spaces % more spaces"]: 15}
            )
        )

        updated = conn.execute(
            full_table.select().order_by(full_table.c["percent%"])
        ).fetchall()
        eq_(updated, [(5, 15), (7, 15), (9, 15), (11, 15)])


class ServerSideCursorsTest(
    fixtures.TestBase, testing.AssertsExecutionResults
):
    """Tests for server-side cursor ("stream_results") behavior,
    covering the legacy engine-wide flag, per-connection and
    per-statement execution options, and fetch round trips."""

    __requires__ = ("server_side_cursors",)

    __backend__ = True

    def _is_server_side(self, cursor):
        """Return whether the raw DBAPI ``cursor`` is server-side.

        Detection is necessarily per-driver; unrecognized drivers
        report False.
        """
        # TODO: this is a huge issue as it prevents these tests from being
        # usable by third party dialects.
        if self.engine.dialect.driver == "psycopg2":
            # psycopg2: a "named" cursor is server side
            return bool(cursor.name)
        elif self.engine.dialect.driver == "pymysql":
            sscursor = __import__("pymysql.cursors").cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.dialect.driver in ("aiomysql", "asyncmy", "aioodbc"):
            return cursor.server_side
        elif self.engine.dialect.driver == "mysqldb":
            sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
            return isinstance(cursor, sscursor)
        elif self.engine.dialect.driver == "mariadbconnector":
            return not cursor.buffered
        elif self.engine.dialect.driver == "mysqlconnector":
            # buffered cursor classes have "buffered" in their class name
            return "buffered" not in type(cursor).__name__.lower()
        elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
            return cursor.server_side
        elif self.engine.dialect.driver == "pg8000":
            return getattr(cursor, "server_side", False)
        elif self.engine.dialect.driver == "psycopg":
            return bool(getattr(cursor, "name", False))
        elif self.engine.dialect.driver == "oracledb":
            return getattr(cursor, "server_side", False)
        else:
            return False

    def _fixture(self, server_side_cursors):
        """Create a testing engine with the legacy engine-wide
        ``server_side_cursors`` flag; enabling the flag is deprecated,
        so the warning is asserted in that case."""
        if server_side_cursors:
            with testing.expect_deprecated(
                "The create_engine.server_side_cursors parameter is "
                "deprecated and will be removed in a future release.  "
                "Please use the Connection.execution_options.stream_results "
                "parameter."
            ):
                self.engine = engines.testing_engine(
                    options={"server_side_cursors": server_side_cursors}
                )
        else:
            self.engine = engines.testing_engine(
                options={"server_side_cursors": server_side_cursors}
            )
        return self.engine

    def stringify(self, str_):
        """Replace each ``SELECT <n>`` in ``str_`` with the compiled
        rendering of ``select(<n>)`` against the current database, so
        the literal SQL is valid for the backend under test."""
        return re.compile(r"SELECT (\d+)", re.I).sub(
            lambda m: str(select(int(m.group(1))).compile(testing.db)), str_
        )

    @testing.combinations(
        ("global_string", True, lambda stringify: stringify("select 1"), True),
        (
            "global_text",
            True,
            lambda stringify: text(stringify("select 1")),
            True,
        ),
        ("global_expr", True, select(1), True),
        (
            "global_off_explicit",
            False,
            lambda stringify: text(stringify("select 1")),
            False,
        ),
        (
            "stmt_option",
            False,
            select(1).execution_options(stream_results=True),
            True,
        ),
        (
            "stmt_option_disabled",
            True,
            select(1).execution_options(stream_results=False),
            False,
        ),
        ("for_update_expr", True, select(1).with_for_update(), True),
        # TODO: need a real requirement for this, or don't use this test
        (
            "for_update_string",
            True,
            lambda stringify: stringify("SELECT 1 FOR UPDATE"),
            True,
            testing.skip_if(["sqlite", "mssql"]),
        ),
        (
            "text_no_ss",
            False,
            lambda stringify: text(stringify("select 42")),
            False,
        ),
        (
            "text_ss_option",
            False,
            lambda stringify: text(stringify("select 42")).execution_options(
                stream_results=True
            ),
            True,
        ),
        id_="iaaa",
        argnames="engine_ss_arg, statement, cursor_ss_status",
    )
    def test_ss_cursor_status(
        self, engine_ss_arg, statement, cursor_ss_status
    ):
        """For each combination, assert whether the resulting cursor is
        server-side given the engine flag and the statement's own
        execution options."""
        engine = self._fixture(engine_ss_arg)
        with engine.begin() as conn:
            # lambda combinations are resolved here with the stringify
            # helper so literal SQL matches the backend's syntax
            if callable(statement):
                statement = testing.resolve_lambda(
                    statement, stringify=self.stringify
                )

            if isinstance(statement, str):
                result = conn.exec_driver_sql(statement)
            else:
                result = conn.execute(statement)
            eq_(self._is_server_side(result.cursor), cursor_ss_status)
            result.close()

    def test_conn_option(self):
        """The per-connection ``stream_results`` execution option
        enables server-side cursors on an engine configured without
        them."""
        engine = self._fixture(False)

        with engine.connect() as conn:
            # should be enabled for this one
            result = conn.execution_options(
                stream_results=True
            ).exec_driver_sql(self.stringify("select 1"))
            assert self._is_server_side(result.cursor)

            # the connection has autobegun, which means at the end of the
            # block, we will roll back, which on MySQL at least will fail
            # with "Commands out of sync" if the result set
            # is not closed, so we close it first.
            #
            # fun fact!  why did we not have this result.close() in this test
            # before 2.0? don't we roll back in the connection pool
            # unconditionally? yes!  and in fact if you run this test in 1.4
            # with stdout shown, there is in fact "Exception during reset or
            # similar" with "Commands out sync" emitted a warning!  2.0's
            # architecture finds and fixes what was previously an expensive
            # silent error condition.
            result.close()

    def test_stmt_enabled_conn_option_disabled(self):
        """A connection-level stream_results=False overrides the
        statement-level stream_results=True."""
        engine = self._fixture(False)

        s = select(1).execution_options(stream_results=True)

        with engine.connect() as conn:
            # not this one
            result = conn.execution_options(stream_results=False).execute(s)
            assert not self._is_server_side(result.cursor)

    def test_aliases_and_ss(self):
        """stream_results set on a subquery does not leak out to the
        enclosing statement."""
        engine = self._fixture(False)
        s1 = (
            select(sql.literal_column("1").label("x"))
            .execution_options(stream_results=True)
            .subquery()
        )

        # options don't propagate out when subquery is used as a FROM clause
        with engine.begin() as conn:
            result = conn.execute(s1.select())
            assert not self._is_server_side(result.cursor)
            result.close()

        s2 = select(1).select_from(s1)
        with engine.begin() as conn:
            result = conn.execute(s2)
            assert not self._is_server_side(result.cursor)
            result.close()

    def test_roundtrip_fetchall(self, metadata):
        """INSERT / SELECT / UPDATE / DELETE round trip using fetchall()
        with the engine-wide server-side flag enabled."""
        # NOTE(review): the ``metadata`` fixture argument appears to set up
        # self.metadata used below — confirm against the fixture definition
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table(
            "test_table",
            md,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)
            connection.execute(test_table.insert(), dict(data="data1"))
            connection.execute(test_table.insert(), dict(data="data2"))
            eq_(
                connection.execute(
                    test_table.select().order_by(test_table.c.id)
                ).fetchall(),
                [(1, "data1"), (2, "data2")],
            )
            connection.execute(
                test_table.update()
                .where(test_table.c.id == 2)
                .values(data=test_table.c.data + " updated")
            )
            eq_(
                connection.execute(
                    test_table.select().order_by(test_table.c.id)
                ).fetchall(),
                [(1, "data1"), (2, "data2 updated")],
            )
            connection.execute(test_table.delete())
            eq_(
                connection.scalar(
                    select(func.count("*")).select_from(test_table)
                ),
                0,
            )

    def test_roundtrip_fetchmany(self, metadata):
        """Incremental fetchmany() calls against a server-side cursor
        deliver consecutive, non-overlapping slices."""
        # NOTE(review): see test_roundtrip_fetchall regarding the
        # metadata fixture / self.metadata relationship
        md = self.metadata

        engine = self._fixture(True)
        test_table = Table(
            "test_table",
            md,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("data", String(50)),
        )

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)
            connection.execute(
                test_table.insert(),
                [dict(data="data%d" % i) for i in range(1, 20)],
            )

            result = connection.execute(
                test_table.select().order_by(test_table.c.id)
            )

            # 19 rows total: 5 + 10 + the remaining 4
            eq_(
                result.fetchmany(5),
                [(i, "data%d" % i) for i in range(1, 6)],
            )
            eq_(
                result.fetchmany(10),
                [(i, "data%d" % i) for i in range(6, 16)],
            )
            eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
